# -*- coding: utf-8 -*-
"""Experience Class."""
from __future__ import annotations

import pickle
import uuid
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional

import torch
from datasets import Dataset
from torch import Tensor


@dataclass
class EID:
    """Hierarchical identifier of a single experience: ``batch/task/run/step/suffix``.

    ``batch`` and ``task`` are filled in automatically by the workflow runner.
    Custom workflows should set ``run`` and ``step`` themselves when creating
    experiences, otherwise grouping by run or by step cannot tell them apart.
    """

    # TODO: do we need to add project/name here to make it unique across different projects?
    batch: int = 0  # Batch index, e.g. the explorer step number (set by the workflow runner)
    task: int = 0  # Task index inside the batch, first task is 0 (set by the workflow runner)
    run: int = 0  # Run index inside the task, first run is 0 (set by custom workflows)
    step: int = 0  # Step index inside the run, first step is 0 (set by custom workflows)
    # Random 6-hex-character tag drawn from a fresh UUID4 for every new EID.
    suffix: str = field(default_factory=lambda: uuid.uuid4().hex[:6])

    @property
    def tid(self) -> str:
        """Task id ``batch/task``.

        All runs of the same task share it, e.g. the group of a GRPO-like algorithm.
        """
        return f"{self.batch}/{self.task}"

    @property
    def rid(self) -> str:
        """Run id ``batch/task/run``, shared by every step of one run of a task."""
        return f"{self.tid}/{self.run}"

    @property
    def sid(self) -> str:
        """Step id ``batch/task/step``, shared by every run of a task at the same step."""
        return f"{self.tid}/{self.step}"

    @property
    def uid(self) -> str:
        """Identifier of this experience: ``batch/task/run/step/suffix``."""
        return f"{self.rid}/{self.step}/{self.suffix}"

    def __str__(self):
        return self.uid

    def __repr__(self):
        # Note: the suffix field is labelled ``uuid`` in this representation.
        labelled = (
            ("batch", self.batch),
            ("task", self.task),
            ("run", self.run),
            ("step", self.step),
            ("uuid", self.suffix),
        )
        return "EID(" + ", ".join(f"{label}={value}" for label, value in labelled) + ")"

    def to_dict(self) -> dict:
        """Return the id components as a plain dict whose keys match the constructor arguments."""
        return {name: getattr(self, name) for name in ("batch", "task", "run", "step", "suffix")}


@dataclass(frozen=True)
class CustomField:
    """Describe one extra value to copy from ``Experience.info`` into a gathered batch.

    Attributes:
        source_field: Key looked up in each experience's ``info`` dict.
        destination_field: Name of the attribute set on the gathered ``Experiences`` object.
        data_type: Torch dtype used to build the gathered tensor,
            e.g. ``torch.float32`` or ``torch.int64``.
    """

    source_field: str
    destination_field: str
    data_type: torch.dtype


@dataclass
class Experience:
    """A single experience: one token sequence plus its training signals.

    ``__init__`` infers ``experience_type`` (an instance attribute, not a
    dataclass field) from its arguments:

    * ``"multi_turn"`` when ``action_mask`` is given;
    * ``"dpo"`` when both ``chosen`` and ``rejected`` are given; ``tokens``
      then holds only the prompt, so ``prompt_length`` is set to ``len(tokens)``;
    * ``"single_turn"`` otherwise; ``action_mask`` is then built as all-``True``
      over the ``len(tokens) - prompt_length`` response tokens.

    Tensor-valued arguments given as Python lists are converted with a fixed
    dtype (``int32`` for token ids, ``float32`` for logprobs/advantages/returns,
    ``bool`` for ``action_mask``). Any other non-Tensor value (tuple, NumPy
    array, ...) is converted with ``torch.tensor``'s inferred dtype.
    """

    eid: EID = field(default_factory=EID)  # Unique identifier for the experience
    tokens: Optional[Tensor] = None  # [seq_length]
    prompt_length: int = 1  # Length of the prompt in tokens, used for generating attention masks
    logprobs: Optional[Tensor] = None  # [resp_length]
    reward: Optional[float] = None
    advantages: Optional[Tensor] = None  # [resp_length]
    returns: Optional[Tensor] = None  # [resp_length]
    info: dict = field(
        default_factory=dict
    )  # Additional information about the experience, can also be used to store custom fields
    metrics: dict[str, float] = field(
        default_factory=dict
    )  # Metrics associated with the experience, directly used by the monitor

    # for single-turn experiences
    response_text: Optional[str] = None  # Text of the response
    prompt_text: Optional[str] = None  # Text of the prompt

    # for multi-turn experiences
    # Action mask indicates which tokens are generated by the model
    action_mask: Optional[Tensor] = None  # [resp_length]
    messages: Optional[List[dict]] = None  # List of messages
    tools: Optional[List[dict]] = None

    # for dpo experiences
    chosen: Optional[Tensor] = None  # Token ids of the chosen response [resp_length]
    rejected: Optional[Tensor] = None  # Token ids of the rejected response [resp_length]
    chosen_messages: Optional[List[dict]] = None  # Chosen message list (Include prompt message)
    rejected_messages: Optional[List[dict]] = None  # Rejected message list (Include prompt message)

    # for multi-modal data
    multi_modal_inputs: Optional[Dict[str, Tensor]] = None  # Multi-modal inputs for verl trainer

    def __init__(  # noqa: C901
        self,
        *,
        eid=None,
        tokens,
        logprobs=None,
        reward=None,
        advantages=None,
        returns=None,
        info=None,
        metrics=None,
        prompt_length=1,
        response_text=None,
        prompt_text=None,
        action_mask=None,
        messages=None,
        tools=None,
        chosen=None,
        rejected=None,
        chosen_messages=None,
        rejected_messages=None,
        multi_modal_inputs=None,
    ):
        """Build an experience, inferring its type and normalizing tensor inputs.

        ``eid`` may be ``None`` (a new ``EID`` is created), a dict of ``EID``
        constructor arguments, or an ``EID`` instance.

        Raises:
            AssertionError: For single-turn experiences, if ``prompt_length <= 0``
                or ``len(tokens) <= prompt_length``.
        """
        if action_mask is not None:
            experience_type = "multi_turn"
        elif chosen is not None and rejected is not None:
            experience_type = "dpo"
        else:
            experience_type = "single_turn"

        if experience_type == "single_turn":
            assert (
                prompt_length > 0
            ), "Prompt length must be greater than 0 for single-turn experiences."
            assert (
                len(tokens) > prompt_length
            ), f"Token ids must be longer than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}."
            # Every response token of a single-turn experience is model-generated.
            action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool)
        elif experience_type == "dpo":
            # For DPO, `tokens` holds only the prompt shared by chosen/rejected.
            prompt_length = len(tokens)

        if eid is None:
            self.eid = EID()
        elif isinstance(eid, dict):
            self.eid = EID(**eid)
        else:
            self.eid = eid

        # `tokens` is required, so unlike the optional tensors below a None value
        # is not passed through (torch.tensor(None) raises).
        if isinstance(tokens, list):
            tokens = torch.tensor(tokens, dtype=torch.int32)
        elif not isinstance(tokens, Tensor):
            tokens = torch.tensor(tokens)
        self.tokens = tokens
        self.logprobs = self._to_tensor(logprobs, torch.float32)
        self.reward = reward
        # Converted like the other per-token tensors so that gather_advantages /
        # gather_returns always receive Tensors (tuples/arrays used to be kept as-is).
        self.advantages = self._to_tensor(advantages, torch.float32)
        self.returns = self._to_tensor(returns, torch.float32)
        self.experience_type = experience_type
        self.info = info or {}
        self.metrics = metrics or {}
        self.prompt_length = prompt_length
        self.response_text = response_text
        self.prompt_text = prompt_text
        self.action_mask = self._to_tensor(action_mask, torch.bool)
        self.messages = messages
        self.tools = tools
        self.chosen = self._to_tensor(chosen, torch.int32)
        self.rejected = self._to_tensor(rejected, torch.int32)
        self.chosen_messages = chosen_messages
        self.rejected_messages = rejected_messages
        if multi_modal_inputs is None:
            self.multi_modal_inputs = None
        else:
            # Copy into a new dict so the caller's mapping is not mutated.
            self.multi_modal_inputs = {
                key: value if isinstance(value, Tensor) else torch.tensor(value)
                for key, value in multi_modal_inputs.items()
            }

    @staticmethod
    def _to_tensor(value, list_dtype: torch.dtype):
        """Return ``value`` as a Tensor; ``None`` and Tensors pass through unchanged.

        Lists use ``list_dtype``; any other value uses ``torch.tensor``'s inferred dtype.
        """
        if value is None or isinstance(value, Tensor):
            return value
        if isinstance(value, list):
            return torch.tensor(value, dtype=list_dtype)
        return torch.tensor(value)

    def serialize(self) -> bytes:
        """Serialize the experience to bytes."""
        return pickle.dumps(self)

    @classmethod
    def deserialize(cls, data: bytes) -> Experience:
        """Rebuild an experience from bytes produced by :meth:`serialize`.

        NOTE(security): this uses ``pickle.loads``, which can execute arbitrary
        code; only call it on data from a trusted source.
        """
        return pickle.loads(data)

    def to_dict(self) -> dict:
        """Summarize the experience as a dict (lengths and metadata, no tensors).

        Always contains ``eid`` (the ``EID`` object itself), ``type``,
        ``prompt_length``, ``response_length``, ``info`` and ``metrics``.
        The optional text/message fields and ``reward`` are included only
        when they are not ``None``.
        """
        res = {
            "eid": self.eid,
            "type": self.experience_type,
            "prompt_length": self.prompt_length,
            "response_length": len(self.tokens) - self.prompt_length,  # type: ignore [arg-type]
            "info": self.info,
            "metrics": self.metrics,
        }
        if self.prompt_text is not None:
            res["prompt_text"] = self.prompt_text
        if self.response_text is not None:
            res["response_text"] = self.response_text
        if self.messages is not None:
            res["messages"] = self.messages
        if self.tools is not None:
            res["tools"] = self.tools
        if self.chosen_messages is not None:
            res["chosen_messages"] = self.chosen_messages
        if self.rejected_messages is not None:
            res["rejected_messages"] = self.rejected_messages
        if self.reward is not None:
            res["reward"] = float(self.reward)
        return res

    @classmethod
    def gather(
        cls,
        experiences: List[Experience],
        pad_token_id: int = 0,
        custom_fields: Optional[List[CustomField]] = None,
    ) -> Experiences:
        """Pad and stack a list of experiences into one ``Experiences`` batch.

        If the first experience is a DPO experience, the list is first split into
        chosen/rejected single-turn pairs. In ``tokens``, prompts are left-padded
        and responses right-padded with ``pad_token_id``. ``logprobs``,
        ``advantages``, ``returns`` and ``multi_modal_inputs`` are gathered only
        when every experience has them; otherwise they are ``None``.

        Args:
            experiences: Experiences to gather.
            pad_token_id: Token id used to pad ``tokens``.
            custom_fields: Extra values copied from each ``exp.info`` into
                attributes of the result.

        Returns:
            The gathered batch, or an empty batch when ``experiences`` is empty.
        """
        if len(experiences) == 0:
            return empty_experiences(custom_fields)
        # The batch kind is decided by the first experience only.
        if experiences[0].experience_type == "dpo":
            experiences = split_dpo_experience_to_single_turn(experiences)
        max_prompt_length = max(exp.prompt_length for exp in experiences)  # type: ignore [type-var]
        max_response_length = max(
            len(exp.tokens) - exp.prompt_length for exp in experiences  # type: ignore [arg-type]
        )
        eids = [exp.eid for exp in experiences]

        tokens = gather_token_ids(experiences, max_prompt_length, max_response_length, pad_token_id)

        # NOTE: only the first experience decides whether rewards are gathered;
        # a later None reward would make torch.tensor fail.
        if experiences[0].reward is not None:
            rewards = torch.tensor([exp.reward for exp in experiences], dtype=torch.float)
        else:
            rewards = None

        action_masks = gather_action_masks(experiences, max_response_length)
        attention_masks = gather_attention_masks(
            experiences, max_prompt_length, max_response_length
        )

        if all(exp.logprobs is not None for exp in experiences):
            logprobs = gather_logprobs(experiences, max_response_length)
        else:
            logprobs = None

        if all(exp.advantages is not None for exp in experiences):
            advantages = gather_advantages(experiences, max_response_length)
        else:
            advantages = None

        if all(exp.returns is not None for exp in experiences):
            returns = gather_returns(experiences, max_response_length)
        else:
            returns = None

        if all(exp.multi_modal_inputs is not None for exp in experiences):
            multi_modal_inputs = gather_multi_modal_inputs(experiences)
        else:
            multi_modal_inputs = None

        exps = Experiences(
            eids=eids,
            tokens=tokens,
            rewards=rewards,
            advantages=advantages,
            returns=returns,
            attention_masks=attention_masks,
            action_masks=action_masks,
            prompt_length=max_prompt_length,
            logprobs=logprobs,
            multi_modal_inputs=multi_modal_inputs,
        )
        if custom_fields is not None:
            for custom_field in custom_fields:
                exps.custom_fields.append(custom_field.destination_field)
                setattr(
                    exps,
                    custom_field.destination_field,
                    torch.tensor(
                        [exp.info[custom_field.source_field] for exp in experiences],
                        dtype=custom_field.data_type,
                    ),
                )
        return exps


def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[Experience]:
    """Expand every DPO experience into two single-turn experiences.

    For each input, a "chosen" experience (prompt ``tokens`` followed by
    ``chosen``, with ``messages=chosen_messages``) is emitted first. It is followed
    by a "rejected" one (prompt ``tokens`` followed by ``rejected``, with
    ``messages=rejected_messages``).

    Both outputs keep the input's batch/task/run/step ids, with a newly
    generated suffix. They also share its reward, ``info``, ``metrics`` and
    prompt text, and use ``len(tokens)`` as their ``prompt_length``.
    """
    expanded: List[Experience] = []
    for dpo_exp in experiences:
        variants = (
            (dpo_exp.chosen, dpo_exp.chosen_messages),
            (dpo_exp.rejected, dpo_exp.rejected_messages),
        )
        for response_tokens, response_messages in variants:
            new_eid = EID(
                batch=dpo_exp.eid.batch,
                task=dpo_exp.eid.task,
                step=dpo_exp.eid.step,
                run=dpo_exp.eid.run,
            )
            expanded.append(
                Experience(
                    eid=new_eid,
                    tokens=torch.cat([dpo_exp.tokens, response_tokens]),
                    reward=dpo_exp.reward,
                    info=dpo_exp.info,
                    metrics=dpo_exp.metrics,
                    prompt_length=len(dpo_exp.tokens),  # type: ignore [arg-type]
                    prompt_text=dpo_exp.prompt_text,
                    messages=response_messages,
                )
            )
    return expanded


@dataclass
class Experiences:
    """A container for a batch of experiences, for high performance communication usage.

    Layout produced by ``Experience.gather``: prompts are left-padded and
    responses right-padded, so all prompts end at column ``prompt_length``::

        tokens ('P' = prompt, 'O' = output, '.' = padding):
                     |<- prompt_length ->|               |
        exp1:        |........PPPPPPPPPPP|OOOOOOOOOO.....|
        exp2:        |......PPPPPPPPPPPPP|OOOOOOO........|

        attention_masks ('.' = False, '1' = True):
        exp1:        |........11111111111|1111111111.....|
        exp2:        |......1111111111111|1111111........|
    """

    eids: List[EID]  # Experience IDs of each experience in the batch
    tokens: Tensor  # [batch_size, seq_length]
    # [batch_size]; Experience.gather sets None when the first experience has no reward
    rewards: Optional[Tensor]
    advantages: Optional[Tensor]  # [batch_size, response_length]
    returns: Optional[Tensor]  # [batch_size, response_length]
    attention_masks: Tensor  # [batch_size, seq_length]
    action_masks: Optional[Tensor]  # [batch_size, response_length]
    prompt_length: int  # Padded prompt length shared by every row
    logprobs: Optional[Tensor]  # [batch_size, response_length]
    multi_modal_inputs: Optional[Any]
    custom_fields: List[str] = field(
        default_factory=list
    )  # Names of the custom-field attributes attached to this batch

    @property
    def batch_size(self) -> int:
        """Number of rows, i.e. the size of the first dimension of ``tokens``."""
        return self.tokens.size(0)

    @classmethod
    def gather_experiences(
        cls,
        experiences: list[Experience],
        pad_token_id: int = 0,
        custom_fields: Optional[List[CustomField]] = None,
    ) -> Experiences:
        """Gather a batch of experiences from a list of experiences.

        Delegates to the ``gather`` classmethod of the first experience's class,
        so subclasses of ``Experience`` can customize batching. With the default
        ``Experience.gather``, ``tokens`` are padded with ``pad_token_id`` and the
        per-token fields (masks, logprobs, advantages, returns) are right-padded
        to a common response length.

        Args:
            experiences (list[Experience]): A list of experiences to gather.
            pad_token_id (int): The token ID to use for padding. Default is 0.
            custom_fields (Optional[List[CustomField]]): Custom fields to include in the gathered experiences.

        Returns:
            The gathered batch, or an empty batch when ``experiences`` is empty.
        """
        if len(experiences) == 0:
            return empty_experiences(custom_fields)
        return experiences[0].__class__.gather(
            experiences, pad_token_id=pad_token_id, custom_fields=custom_fields
        )


def empty_experiences(custom_fields: Optional[List[CustomField]]) -> Experiences:
    """Build an ``Experiences`` batch that contains no experiences.

    Every tensor field is an empty 1-D tensor with the dtype a non-empty batch
    would use. ``prompt_length`` is the int ``0``, matching the int used by
    non-empty batches. Each requested custom field is attached as an empty
    tensor of its ``data_type`` and listed in ``custom_fields``.

    Args:
        custom_fields: Custom fields to attach, or ``None``.

    Returns:
        An ``Experiences`` whose ``tokens`` has zero rows.
    """
    exps = Experiences(
        tokens=torch.empty(0, dtype=torch.int32),
        rewards=torch.empty(0, dtype=torch.float32),
        advantages=torch.empty(0, dtype=torch.float32),
        returns=torch.empty(0, dtype=torch.float32),
        attention_masks=torch.empty(0, dtype=torch.bool),
        action_masks=torch.empty(0, dtype=torch.bool),
        logprobs=torch.empty(0, dtype=torch.float32),
        # Was an empty tensor; the field is an int on every non-empty batch.
        prompt_length=0,
        eids=[],
        # NOTE(review): non-empty batches hold a dict or None here; confirm whether
        # consumers rely on this empty tensor before changing it.
        multi_modal_inputs=torch.empty(0, dtype=torch.float32),
    )
    if custom_fields is not None:
        for custom_field in custom_fields:
            exps.custom_fields.append(custom_field.destination_field)
            setattr(
                exps, custom_field.destination_field, torch.empty(0, dtype=custom_field.data_type)
            )
    return exps


def gather_token_ids(
    experiences, max_prompt_length: int, max_response_length: int, pad_token_id: int
) -> Tensor:
    """Align token sequences at the prompt/response boundary and stack them.

    Each row is ``[left padding | tokens | right padding]``. The left padding
    brings the prompt up to ``max_prompt_length``, and the right padding brings
    the response (``len(tokens) - prompt_length``) up to ``max_response_length``.
    Padding uses ``pad_token_id`` with the dtype of the first experience's tokens.
    """
    pad_dtype = experiences[0].tokens.dtype
    rows = []
    for exp in experiences:
        left = max_prompt_length - exp.prompt_length
        right = max_response_length - (len(exp.tokens) - exp.prompt_length)
        left_pad = torch.full((left,), pad_token_id, dtype=pad_dtype)
        right_pad = torch.full((right,), pad_token_id, dtype=pad_dtype)
        rows.append(torch.cat([left_pad, exp.tokens, right_pad]))
    return torch.stack(rows)


def gather_action_masks(experiences, max_response_length: int) -> Tensor:
    """Right-pad every ``action_mask`` with ``False`` to ``max_response_length`` and stack them."""
    padded_masks = []
    for exp in experiences:
        mask = exp.action_mask
        filler = torch.zeros((max_response_length - len(mask),), dtype=torch.bool)
        padded_masks.append(torch.cat([mask, filler]))
    return torch.stack(padded_masks)


def gather_attention_masks(experiences, max_prompt_length: int, max_response_length: int) -> Tensor:
    """Build boolean attention masks matching the layout of ``gather_token_ids``.

    Row ``i`` is True over the span occupied by experience ``i``'s tokens.
    That span starts after ``max_prompt_length - prompt_length`` left-padding
    columns and covers ``len(tokens)`` positions. Every other column is False.
    """
    total_length = max_prompt_length + max_response_length
    masks = torch.zeros((len(experiences), total_length), dtype=torch.bool)
    for row, exp in enumerate(experiences):
        offset = max_prompt_length - exp.prompt_length
        masks[row, offset : offset + len(exp.tokens)] = True
    return masks


def gather_logprobs(experiences, max_response_length: int) -> Tensor:
    """Right-pad every ``logprobs`` with zeros to ``max_response_length`` and stack them.

    Padding uses the dtype of the first experience's ``logprobs``. Every
    experience must have ``logprobs`` set, since no None check is done here.
    """
    pad_dtype = experiences[0].logprobs.dtype  # type: ignore [union-attr]
    stacked_rows = []
    for exp in experiences:
        padding = torch.zeros(max_response_length - len(exp.logprobs), dtype=pad_dtype)
        stacked_rows.append(torch.cat((exp.logprobs, padding)))
    return torch.stack(stacked_rows)


def gather_advantages(experiences, max_response_length: int) -> Optional[Tensor]:
    """Right-pad every ``advantages`` with zeros to ``max_response_length`` and stack them.

    Returns ``None`` when the first experience has no advantages. Otherwise
    padding uses the dtype of the first experience's ``advantages``.
    """
    first_advantages = experiences[0].advantages
    if first_advantages is None:
        return None
    pad_dtype = first_advantages.dtype
    padded = [
        torch.cat(
            (exp.advantages, torch.zeros(max_response_length - len(exp.advantages), dtype=pad_dtype))
        )
        for exp in experiences
    ]
    return torch.stack(padded)


def gather_returns(experiences, max_response_length: int) -> Optional[Tensor]:
    """Right-pad every ``returns`` with 0.0 to ``max_response_length`` and stack them.

    The return annotation previously claimed a dict of tensor lists; the
    function actually returns a stacked Tensor (or None).

    Args:
        experiences: Non-empty sequence of experiences. Each ``returns`` is
            presumably a 1-D tensor (TODO confirm for multi-dimensional inputs).
        max_response_length: Target length after padding.

    Returns:
        ``None`` if the first experience has no ``returns``. Otherwise a tensor
        stacking the padded rows, using the first experience's dtype for padding.
    """
    if experiences[0].returns is None:
        return None
    returns_dtype = experiences[0].returns.dtype
    return torch.stack(
        [
            torch.cat(
                [
                    exp.returns,
                    torch.full(
                        (max_response_length - len(exp.returns),),
                        0.0,
                        dtype=returns_dtype,
                    ),
                ]
            )
            for exp in experiences
        ]
    )


def gather_multi_modal_inputs(experiences) -> Dict[str, List[Tensor]]:
    """Collect each multi-modal input key into a per-key list across the batch.

    Values are neither stacked nor padded: ``result[key][i]`` is
    ``experiences[i].multi_modal_inputs[key]``. The return annotation previously
    claimed ``Dict[str, Tensor]``, but the values are lists.

    Keys are taken from the first experience. A later experience that lacks one
    of them raises ``KeyError``.
    """
    keys = experiences[0].multi_modal_inputs.keys()
    return {key: [exp.multi_modal_inputs[key] for exp in experiences] for key in keys}


def group_by(
    experiences: List[Experience], id_type: Literal["task", "run", "step"]
) -> Dict[str, List[Experience]]:
    """Group experiences by one level of their ``EID``.

    Args:
        experiences: Experiences to group. Input order is preserved inside each group.
        id_type: ``"task"`` groups by ``eid.tid``, ``"run"`` by ``eid.rid`` and
            ``"step"`` by ``eid.sid``.

    Returns:
        A dict mapping each group id to its experiences, with keys in order of
        first appearance.

    Raises:
        ValueError: If ``id_type`` is not one of the values above.
    """
    eid_attribute_for = {"task": "tid", "run": "rid", "step": "sid"}
    try:
        eid_attribute = eid_attribute_for[id_type]
    except (KeyError, TypeError):
        raise ValueError(f"Unknown id_type: {id_type}") from None
    buckets = {}
    for exp in experiences:
        buckets.setdefault(getattr(exp.eid, eid_attribute), []).append(exp)
    return buckets


def to_hf_datasets(experiences: list[Experience]) -> Dataset:
    """Convert experiences into a HuggingFace ``Dataset`` with one row per experience.

    Each row is ``dataclasses.asdict(exp)``, so every dataclass field becomes a
    column and the nested ``eid`` dataclass becomes a dict. ``experience_type``
    is set in ``Experience.__init__`` rather than declared as a field, so it is
    not stored; constructing an ``Experience`` from a row infers it again.
    """
    rows = list(map(asdict, experiences))
    return Dataset.from_list(rows)


def from_hf_datasets(dataset: Dataset) -> List[Experience]:
    """Rebuild ``Experience`` objects from a HuggingFace ``Dataset``, one per row.

    Columns that are not ``Experience`` dataclass fields are ignored. Every
    remaining column is passed to ``Experience(...)`` as a keyword argument.
    """
    experience_fields = {f.name for f in fields(Experience)}
    return [
        Experience(**{column: value for column, value in row.items() if column in experience_fields})
        for row in dataset.to_list()
    ]
